Skip to content

[Common] Enable NVFP4 2D block scaling in columnwise only#3027

Open
negvet wants to merge 4 commits into
NVIDIA:mainfrom
negvet:nvfp4_2d_colwise_only
Open

[Common] Enable NVFP4 2D block scaling in columnwise only#3027
negvet wants to merge 4 commits into
NVIDIA:mainfrom
negvet:nvfp4_2d_colwise_only

Conversation

@negvet
Copy link
Copy Markdown
Collaborator

@negvet negvet commented May 21, 2026

Description

Enabling 2D NVFP4 quantization in columnwise-only mode.
Needed by HybridQuantizer (PR #2817) for MXFP8 fwd + NVFP4 bwd on W.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

negvet and others added 2 commits May 21, 2026 17:35
Signed-off-by: Evgeny <etsykunov@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 21, 2026

Greptile Summary

This PR enables 2D NVFP4 quantization in columnwise-only mode, needed by HybridQuantizer for MXFP8-fwd + NVFP4-bwd on weights. Two separate code paths are updated: the optimized TMA-based kernel (quantize_transpose_nvfp4_2D_kernel) gains a RETURN_ROWWISE template boolean that gates out the rowwise scaling pass and data store at compile time, and the fallback blockwise kernel (block_scaled_1d_cast_transpose_kernel) gains a new "Step 2.5" that re-runs the load + local-amax + 2D warp/smem reduction from Step 2 when only the transposed output is requested, populating amax_smem for Step 3 without performing the rowwise scale/quantize/store writes.

  • Dispatch (quantize.cuh): use_optimized_kernel is extended to route BF16 + 32-aligned + columnwise-only + 2D requests to the TMA kernel instead of falling back to the blockwise path.
  • Optimized path (quantize_transpose_nvfp4.cuh): RETURN_ROWWISE template parameter added; scale_stride, scales_ptr, and tensor_map_output are all guarded so the kernel safely no-ops on their rowwise uses when RETURN_ROWWISE=false.
  • Fallback path (quantize_transpose_vector_blockwise_fp4.cu): New Step 2.5 correctly mirrors the 2D amax-reduction pass from Step 2, with matching __syncthreads() barriers, so amax_smem is fully populated for Step 3 in the columnwise-only 2D case.

Confidence Score: 5/5

The changes are well-scoped: new code paths are gated by constexpr template parameters or function-level preconditions, and the new Step 2.5 amax-only pass faithfully mirrors the existing 2D reduction without touching any existing output writes.

All access sites for rowwise pointers and tensor maps in the columnwise-only case are guarded at compile time (constexpr if). Synchronization in Step 2.5 is correct: amax_smem_red is written and synced before reduction, and amax_smem is synced before Step 3 reads it. The bitwise-equality test covers both the TMA kernel (BF16 + aligned) and the blockwise fallback (float32 + non-aligned), leaving minimal room for undetected regressions.

No files require special attention; the most complex change (Step 2.5 in the .cu file) is validated by the new test.

Important Files Changed

Filename Overview
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu Adds Step 2.5 amax-only pass for columnwise-only 2D mode; removes the early-return guard and function-level check that previously blocked this case. Synchronization with __syncthreads() is correct, amax_smem_red/amax_smem are written before read, and Step 3's amax_smem reads are unaffected for the 1D case.
transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh Adds RETURN_ROWWISE template parameter to quantize_transpose_nvfp4_2D_kernel; gates rowwise scaling section and TMA cp_async rowwise store behind it. scale_stride=0 and uninitialized tensor_map_output for the columnwise-only path are harmless since all access sites are constexpr-guarded by RETURN_ROWWISE.
transformer_engine/common/cast/dispatch/quantize.cuh Extends use_optimized_kernel condition symmetrically in both fwd and bwd helpers to permit BF16 + aligned + columnwise-only 2D requests to reach the TMA kernel; comment accurately describes the routing.
tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py Adds bitwise-comparison test covering both the TMA kernel (BF16 + aligned shapes) and the blockwise fallback (float32 + non-aligned shapes). Valid-region slicing for the scale tensor and _rowwise_data=None assertion are correct.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["quantize_fwd/bwd_helper\n(quantize.cuh)"] --> B{use_optimized_kernel?}
    B -->|"BF16 + rows%32==0 + cols%32==0\n+ (has_data OR (has_colwise_data AND 2D))"| C["quantize_transpose\n(quantize_transpose_nvfp4.cuh)"]
    B -->|else| D["quantize_transpose_vector_blockwise_fp4\n(fallback kernel)"]
    C --> E{use_2d_quantization?}
    E -->|true| F["quantize_transpose_nvfp4_2D_kernel\nRETURN_ROWWISE=bool\nRETURN_TRANSPOSE=bool"]
    E -->|false| G["quantize_transpose_nvfp4_kernel (1D)\nguarded: return_rowwise||use_2d must be true"]
    F --> H{RETURN_ROWWISE?}
    H -->|true| I["Step: 2D amax + rowwise scale/quantize/store\n+ TMA cp_async rowwise data"]
    H -->|false| J["Step: 2D amax only\n(block_amax_matrix populated)"]
    I --> K{RETURN_TRANSPOSE?}
    J --> K
    K -->|true| L["Colwise scale/quantize/store\n+ TMA cp_async colwise data"]
    D --> M{kReturnIdentity\n&& kIs2DBlockScaling?}
    M -->|"kReturnIdentity=true"| N["Step 2: full rowwise pass\n(populates amax_smem)"]
    M -->|"!kReturnIdentity\n&& kIs2DBlockScaling=true"| O["Step 2.5 (NEW): amax-only pass\nload smem → amax_smem_red → amax_smem"]
    N --> P["Step 3: transpose cast+store\n(reads amax_smem for 2D)"]
    O --> P
Loading

Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

}
}

// Step 2.5: 2D-amax-only pass for columnwise-only mode.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Step label collision with existing substep

The new outer-level block is named "Step 2.5" at line 576, but that same label is already used at line 522 for the "Write scale_inv" substep inside Step 2's loop (if constexpr (kReturnIdentity)). A future reader scanning the file will find two distinct "Step 2.5" sections with different semantics. Consider renaming the new block to something like "Step 2b" or "Step 2.5 (outer)" to distinguish it from the // Step 2.5: Write scale_inv substep inside the inner loop.

@ptrendx
Copy link
Copy Markdown
Member

ptrendx commented May 21, 2026

This is just the fallback kernel being changed. Does the main kernel already support this?

Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet negvet requested a review from Oleg-Goncharov as a code owner June 1, 2026 11:34
@negvet
Copy link
Copy Markdown
Collaborator Author

negvet commented Jun 1, 2026

This is just the fallback kernel being changed. Does the main kernel already support this?

Thanks for the catch. The main kernel does not support if as well. Enabled in f7953dd

ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));
if constexpr (RETURN_ROWWISE) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already inside the if constexpr (RETURN_ROWWISE) scope (starting at line 1131), so it can be removed safely.

NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
NVTE_CHECK(return_rowwise || return_transpose,
"At least one of rowwise/columnwise NVFP4 output must be allocated.");
NVTE_CHECK(return_rowwise || use_2d_quantization,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit confusing to read, especially if the kernel is extended in the future to support additional quantization schemes. It would be better to restrict the supported combinations explicitly, e.g.
NVTE_CHECK((return_transpose && use_2d_quantization) || (return_rowwise && !use_2d_quantization),


// Step 2.5: 2D-amax-only pass for columnwise-only mode.
// When only the transposed output is requested but 2D block scaling is enabled, the columnwise
// reads in Step 3 (line ~660 below) still need amax_smem populated. Re-run the load + local-amax
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment refers to line ~660, which is now line 637. Let’s maybe remove the line reference entirely to avoid confusion.

Copy link
Copy Markdown
Collaborator

@Oleg-Goncharov Oleg-Goncharov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also add a corresponding C++ unit test to cover this, since this PR changes logic in the common part of the library

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants